--- ../../swvox/swvox/renderer.py	2024-02-14 05:16:57.541407536 -0500
+++ ../google/swvox/swvox/renderer.py	2024-02-13 10:46:24.576645015 -0500
@@ -31,7 +31,7 @@
 from collections import namedtuple
 from warnings import warn
 
-from svox.helpers import _get_c_extension, LocalIndex, DataFormat
+from swvox.helpers import _get_c_extension, LocalIndex, DataFormat, WaveletType
 
 NDCConfig = namedtuple('NDCConfig', ["width", "height", "focal"])
 Rays = namedtuple('Rays', ["origins", "dirs", "viewdirs"])
@@ -39,14 +39,21 @@
 _C = _get_c_extension()
 
 def _rays_spec_from_rays(rays):
-    spec = _C.RaysSpec()
+    spec = _C.RaysSpecSwvox()
     spec.origins = rays.origins
     spec.dirs = rays.dirs
     spec.vdirs = rays.viewdirs
+    # TODO(machc): this assumes dirs is unit-norm!
+    # TODO(machc): better to have a single scalar offset on t?
+    # spec.random_offset = (torch.rand_like(rays.dirs[..., [0]])
+    #                       - 0.5) / 2.0**9 * rays.dirs
+    # zero this out until we fix the step sizes
+    spec.random_offset = torch.zeros_like(rays.dirs)
+
     return spec
 
 def _make_camera_spec(c2w, width, height, fx, fy):
-    spec = _C.CameraSpec()
+    spec = _C.CameraSpecSwvox()
     spec.c2w = c2w
     spec.width = width
     spec.height = height
@@ -71,6 +78,7 @@
             ), None, None, None
         return None, None, None, None
 
+
 class _VolumeRenderImageFunction(autograd.Function):
     @staticmethod
     def forward(ctx, data, tree, cam, opt):
@@ -123,6 +131,11 @@
             max_comp : int=-1,
             density_softplus : bool=False,
             rgb_padding : float=0.0,
+            accumulate: bool=False,
+            accumulate_sigma: bool=False,
+            piecewise_linear: bool=False,
+            render_distance: bool=False,
+            sigma_penalty: float=0.00,
         ):
         """
         Construct volume renderer associated with given N^3 tree.
@@ -163,7 +176,7 @@
                         Please note the padding will NOT be compatible with volrend,
                         although most likely the effect is very small.
                         0.001 is a reasonable value to try.
-
+        :param accumulate: accumulate values from root to leaf.
         """
         super().__init__()
         self.tree = tree
@@ -174,6 +187,15 @@
         self.max_comp = max_comp
         self.density_softplus = density_softplus
         self.rgb_padding = rgb_padding
+        self.accumulate = accumulate
+        self.piecewise_linear = piecewise_linear
+        
+        assert (accumulate and tree.lowpass_depth >= 0) or (tree.lowpass_depth == -1), \
+            "If using accumulation, lowpass_depth must be >=0 and if not using accumulation lowpass_depth must be -1"
+        
+        self.accumulate_sigma = accumulate_sigma
+        self.render_distance = render_distance
+        self.sigma_penalty = sigma_penalty
         if isinstance(tree.data_format, DataFormat):
             self._data_format = None
         else:
@@ -186,11 +208,11 @@
                 self._data_format = DataFormat(f"SH{(ddim - 1) // 3}")
         self.tree._weight_accum = None
 
-    def forward(self, rays : Rays, cuda=True, fast=False):
+    def forward(self, rays : Rays, cuda=True, fast=False, backward_absolute_values=False):
         """
         Render a batch of rays. Differentiable.
 
-        :param rays: namedtuple :code:`svox.Rays` of origins
+        :param rays: namedtuple :code:`swvox.Rays` of origins
                      :code:`(B, 3)`, dirs :code:`(B, 3):, viewdirs :code:`(B, 3)`
         :param cuda: whether to use CUDA kernel if available. If false,
                      uses only PyTorch version. *Only True supported right now*
@@ -206,7 +228,13 @@
             assert self.data_format.format in [DataFormat.RGBA, DataFormat.SH], \
                  "Unsupported data format for slow volume rendering"
             warn("Using slow volume rendering, should only be used for debugging")
+            assert not backward_absolute_values, "Only implemented with CUDA"
             def dda_unit(cen, invdir):
+                # Computes the AABB/ray intersection for a voxel spanning (0,0,0) to (1,1,1).
+                # Does not return negative values: if it's inside the voxel, tmin is 0.
+                # The intersections are at t_min * 1/invdir (if t_min > 0); otherwise it is inside the voxel.
+                # 
+                
                 """
                 voxel aabb ray tracing step
                 :param cen: jnp.ndarray [B, 3] center
@@ -232,7 +260,7 @@
 
             sh_mult = None
             if self.data_format.format == DataFormat.SH:
-                from svox import sh
+                from swvox import sh
                 sh_order = int(self.data_format.basis_dim ** 0.5) - 1
                 sh_mult = sh.eval_sh_bases(sh_order, viewdirs)[:, None]
 
@@ -244,8 +272,11 @@
             good_indices = torch.arange(B, device=origins.device)
             delta_scale = (dirs / self.tree.invradius[None]).norm(dim=1)
             while good_indices.numel() > 0:
+                # This is for each sample.
                 pos = origins + t[:, None] * dirs
+                # Fetches the voxels containing each point along this depth.
                 treeview = self.tree[LocalIndex(pos)]
+                # Could this be SH?
                 rgba = treeview.values
                 cube_sz = treeview.lengths_local
                 pos_t = (pos - treeview.corners_local) / cube_sz[:, None]
@@ -253,7 +284,9 @@
 
                 subcube_tmin, subcube_tmax = dda_unit(pos_t, invdirs)
 
+                # delta is length inside the cube + step_size
                 delta_t = (subcube_tmax - subcube_tmin) * cube_sz + self.step_size
+                # Integrating density.
                 att = torch.exp(- delta_t * torch.relu(rgba[..., -1]) * delta_scale[good_indices])
                 weight = light_intensity[good_indices] * (1.0 - att)
                 rgb = rgba[:, :-1]
@@ -284,8 +317,8 @@
             self.tree.data,
             self.tree._spec(),
             _rays_spec_from_rays(rays),
-            self._get_options(fast)
-        )
+            self._get_options(fast, backward_absolute_values),
+            )
 
     def render_persp(self, c2w, width=800, height=800, fx=1111.111, fy=None,
             cuda=True, fast=False):
@@ -336,7 +369,7 @@
         The tree's rendered output dimension (rgb_dim) cannot
         be greater than 4 (this is almost always true, don't need to worry).
 
-        :param rays: namedtuple :code:`svox.Rays` of origins
+        :param rays: namedtuple :code:`swvox.Rays` of origins
                      :code:`(B, 3)`, dirs :code:`(B, 3):, viewdirs :code:`(B, 3)`
         :param colors: torch.Tensor :code:`(B, 3)` reference colors
 
@@ -385,11 +418,10 @@
             colors)
 
     @staticmethod
-    def persp_rays(c2w, width=800, height=800, fx=1111.111, fy=None):
+    def persp_rays(c2w, width=800, height=800, fx=1111.111, fy=None, ndc_config=None, rays_at_center_pixel=False):
         """
         Generate perspective camera rays in row major order, then
         usable for renderer's forward method.
-        *NDC is not supported currently.*
 
         :param c2w: torch.Tensor (3, 4) or (4, 4) camera pose matrix (c2w)
         :param width: int output image width
@@ -406,8 +438,8 @@
             fy = fx
         origins = c2w[None, :3, 3].expand(height * width, -1).contiguous()
         yy, xx = torch.meshgrid(
-            torch.arange(height, dtype=torch.float64, device=c2w.device),
-            torch.arange(width, dtype=torch.float64, device=c2w.device),
+            torch.arange(height, dtype=torch.float64, device=c2w.device) + (0.5 if rays_at_center_pixel else 0.0),
+            torch.arange(width, dtype=torch.float64, device=c2w.device) + (0.5 if rays_at_center_pixel else 0.0),
         )
         xx = (xx - width * 0.5) / float(fx)
         yy = (yy - height * 0.5) / float(fy)
@@ -419,21 +451,34 @@
         dirs = torch.matmul(c2w[None, :3, :3].double(), dirs[..., None])[..., 0].float()
         vdirs = dirs
 
-        return Rays(
-            origins=origins,
-            dirs=dirs,
-            viewdirs=vdirs
-        )
+        if not ndc_config is None:
+            ndc_origins, ndc_directions = convert_to_ndc(
+                origins, dirs, fx, width, height
+            )
+            ndc_directions /= torch.norm(ndc_directions, dim=-1, keepdim=True)
+
+            return Rays(
+                origins=ndc_origins,
+                dirs=ndc_directions,
+                viewdirs=dirs,
+            )
+        else:
+            return Rays(
+                origins=origins,
+                dirs=dirs,
+                viewdirs=vdirs
+            )
+
 
     @property
     def data_format(self):
         return self._data_format or self.tree.data_format
 
-    def _get_options(self, fast=False):
+    def _get_options(self, fast=False, backward_absolute_values=False):
         """
-        Make RenderOptions struct to send to C++
+        Make RenderOptionsSwvox struct to send to C++
         """
-        opts = _C.RenderOptions()
+        opts = _C.RenderOptionsSwvox()
         opts.step_size = self.step_size
         opts.background_brightness = self.background_brightness
 
@@ -442,6 +487,8 @@
         opts.min_comp = self.min_comp
         opts.max_comp = self.max_comp
 
+        assert opts.basis_dim > 0, "At least should be 1 for RGBA"
+
         if self.max_comp < 0:
             opts.max_comp += opts.basis_dim
 
@@ -456,6 +503,7 @@
             opts.ndc_width = -1
 
         if fast:
+            print("Warning: user fast mode! This may lead to loss of accuracy. Not to use during training!!!!.")
             opts.sigma_thresh = 1e-2
             opts.stop_thresh = 1e-2
         else:
@@ -466,4 +514,34 @@
             opts.sigma_thresh = self.sigma_thresh
         if hasattr(self, "stop_thresh"):
             opts.stop_thresh = self.stop_thresh
+            
+        # opts: add the accumulation flags and the wavelet type
+        opts.accumulate = self.accumulate
+        opts.accumulate_sigma = self.accumulate_sigma
+        
+        opts.piecewise_linear = self.piecewise_linear
+        
+        opts.wavelet_type = self.tree.wavelet_type.wavelet_type
+        opts.lowpass_depth = self.tree.lowpass_depth
+        
+        # TODO: check that all leaves have minimum depth
+        # This check may make the rendering slow, as we need to instantiate self.depths
+        
+        if self.tree.wavelet_type.wavelet_type not in [WaveletType.CONSTANT, 
+                                                       WaveletType.TRILINEAR,
+                                                       WaveletType.SIDE]:
+            assert opts.accumulate, "wavelet type is not constant, so accumulate must be true to compute the full pyramid if you are not debugging!"
+                
+        opts.eval_wavelet_integral = self.tree.eval_wavelet_integral
+        opts.linear_color = self.tree.linear_color
+        opts.wavelet_sigma = self.tree.wavelet_sigma
+        
+        
+        # for debugging purposes, render depth along the ray instead of the color
+        opts.render_distance = self.render_distance
+        opts.sigma_penalty = self.sigma_penalty
+        
+        # To find which nodes to refine, we aggregate the absolute values rather than the values themselves, so that -1 and 1 on the same node don't cancel.
+        opts.backward_absolute_values = backward_absolute_values
+        
         return opts
